/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License. 
*/

#include "acl/acl.h"
#include "../../include/custom_type.h"
#include "../data_utils.h"
#include "../run_main.h"
#include "./LLMsGEMM_batch_QKTVP_host.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>


int main ( int argc, char ** argv ) {

    // 获取输入 deviceId layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount
    uint32_t deviceId = std::stoi( argv[1] ); 
    layoutType layoutA = ( std::stoi( argv[2] ) == 0 ? RowMajor : ColumnMajor );
    layoutType layoutB = ( std::stoi( argv[3] ) == 0 ? RowMajor : ColumnMajor );
    uint32_t zeroPaddingM = std::stoi( argv[4] );
    uint32_t zeroPaddingN = std::stoi( argv[5] );
    uint32_t zeroPaddingK = std::stoi( argv[6] );
    uint32_t batchCount = std::stoi( argv[7] );  
    
    // 打印输出 layoutA layoutB zeroPaddingM zeroPaddingN zeroPaddingK batchCount
    printf("\nTesting kernel on device %d. \n"
            "Getting test input: \n"
            "layoutA: %s, layoutB: %s, "
            "zeroPaddingM: %d, zeroPaddingN: %d, zeroPaddingK: %d, " 
            "batchCount: %d. \n", 
            deviceId, 
            layoutA == 0 ? "RowMajor" : "ColumnMajor", layoutB == 0 ? "RowMajor" : "ColumnMajor", 
            zeroPaddingM, zeroPaddingN, zeroPaddingK, 
            batchCount);

    if(deviceId < 0 || deviceId > 7 
        || !(layoutA == RowMajor || layoutA == ColumnMajor)
        || !(layoutB == RowMajor || layoutB == ColumnMajor)
        || zeroPaddingM <= 0
        || zeroPaddingN <= 0
        || zeroPaddingK <= 0
        || batchCount <= 0){
        printf("Wrong input! \n"); 
        return 0; 
    }

    // acl初始化
    const char *aclConfigPath = "../acl.json";
    ACL_CHECK(aclInit(/*aclConfigPath*/nullptr));
    ACL_CHECK(aclrtSetDevice(deviceId));
    aclrtStream stream;
    ACL_CHECK(aclrtCreateStream(&stream));

    std::string src="./data/";
    std::string pathA = src + "A.bin";
    std::string pathB = src + "B.bin";
    std::string pathC = src + "C.bin";
    std::string pathAlpha = src + "alpha.bin";
    std::string pathBeta  = src + "beta.bin";
    std::string pathMaskA = src + "maskA.bin";
    std::string pathExpectResult = src + "expect_result.bin";

    half* preOutput = nullptr;
    half* curOutput = nullptr;

    run_main(
        LLMsGEMM_batch_QKTVP_host,
        stream, 
        layoutA, 
        layoutB, 
        zeroPaddingM, 
        zeroPaddingN, 
        zeroPaddingK, 
        batchCount, 
        pathA, 
        pathB, 
        pathC, 
        pathAlpha, 
        pathBeta, 
        pathMaskA, 
        pathExpectResult, 
        preOutput, 
        curOutput, 
        3, 
        0
    );

    ACL_CHECK(aclrtFree(curOutput));

    // 反初始化
    ACL_CHECK(aclrtDestroyStream(stream));
    ACL_CHECK(aclrtResetDevice(deviceId));
    ACL_CHECK(aclFinalize());

    return 0;
}